
package com.jstarcraft.ai.jsat.distributions;

import static com.jstarcraft.ai.jsat.math.SpecialMath.betaIncReg;
import static com.jstarcraft.ai.jsat.math.SpecialMath.invBetaIncReg;
import static com.jstarcraft.ai.jsat.math.SpecialMath.lnGamma;
import static java.lang.Math.PI;
import static java.lang.Math.exp;
import static java.lang.Math.log;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;

import java.util.Random;

import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.text.GreekLetters;

/**
 * The location-scale Student's t-distribution with {@code df} degrees of
 * freedom (&nu;), location {@code mu} (&mu;) and scale {@code sig} (&sigma;).
 * With {@code mu = 0} and {@code sig = 1} this is the standard Student's
 * t-distribution.
 * <p>
 * Note that {@code sig} is a scale parameter, not the standard deviation; the
 * variance is {@code sig^2 * df / (df - 2)} for {@code df > 2}.
 * 
 * @author Edward Raff
 */
public class StudentT extends ContinuousDistribution {

    private static final long serialVersionUID = -3157525461647767831L;
    /** Degrees of freedom (&nu;). */
    double df;
    /** Location parameter (&mu;). */
    double mu;
    /** Scale parameter (&sigma;). */
    double sig;

    /**
     * Creates a standard Student's t-distribution (location 0, scale 1).
     * 
     * @param df the degrees of freedom
     */
    public StudentT(double df) {
        this(df, 0, 1);
    }

    /**
     * Creates a location-scale Student's t-distribution. The arguments are
     * stored as given; no range validation is performed here.
     * 
     * @param df  the degrees of freedom
     * @param mu  the location parameter
     * @param sig the scale parameter
     */
    public StudentT(double df, double mu, double sig) {
        this.df = df;
        this.mu = mu;
        this.sig = sig;
    }

    /**
     * Sets the degrees of freedom used by the test.
     * <p>
     * Unlike {@link #setVariable(String, double)}, no range validation is
     * performed.
     * 
     * @param df the new value for the degrees of freedom.
     */
    public void setDf(double df) {
        this.df = df;
    }

    /**
     * Evaluates the probability density function:
     * <pre>
     * Gamma((df+1)/2) / (Gamma(df/2) * sqrt(df*pi) * sig) * (1 + z^2/df)^(-(df+1)/2)
     * </pre>
     * where {@code z = (t - mu) / sig}. The computation is done in log space.
     * 
     * @param t the point at which to evaluate the density
     * @return the density at {@code t}
     */
    @Override
    public double pdf(double t) {
        // log of the normalizing constant; the sqrt(df*pi) factor is a plain
        // logarithm, not a log-gamma term
        double logNorm = lnGamma((df + 1) / 2) - lnGamma(df / 2) - log(df * PI) / 2 - log(sig);
        double z = (t - mu) / sig;
        // log1p keeps precision when z^2/df is tiny (t close to mu)
        double logKernel = -(df + 1) / 2 * Math.log1p(z * z / df);

        return exp(logNorm + logKernel);
    }

    /**
     * Evaluates the cumulative distribution function using the regularized
     * incomplete beta function: with {@code x = df / (df + z^2)}, the mass of
     * the tail beyond {@code |z|} is {@code I_x(df/2, 1/2) / 2}.
     * 
     * @param t the point at which to evaluate the CDF
     * @return P(X &le; t)
     */
    @Override
    public double cdf(double t) {
        double x = df / (df + pow((t - mu) / sig, 2));

        // one-sided tail probability
        double p = betaIncReg(x, df / 2, 0.5) / 2;

        if (t > mu)
            return 1 - p;
        else
            return p;
    }

    /**
     * Inverse of {@link #cdf(double)}, obtained by inverting the regularized
     * incomplete beta relation used there.
     * 
     * @param p a probability in [0, 1]
     * @return the value {@code t} such that {@code cdf(t) == p}
     * @throws ArithmeticException if {@code p} is outside [0, 1] or is NaN
     */
    @Override
    public double invCdf(double p) {
        // written as a negated in-range test so that NaN is rejected as well
        if (!(p >= 0 && p <= 1))
            throw new ArithmeticException("Probability must be in the range [0,1], not " + p);
        // 2*min(p, 1-p) is the two-sided tail mass I_x(df/2, 1/2); solve for x
        double x = invBetaIncReg(2 * Math.min(p, 1 - p), df / 2, 0.5);
        // invert x = df / (df + z^2) to get |t - mu|
        x = sig * sqrt(df * (1 - x) / x);

        if (p >= 0.5)
            return mu + x;
        else
            return mu - x;
    }

    @Override
    public double min() {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double max() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public String getDescriptiveName() {
        return "Student-T(df=" + df + ", \u03BC=" + mu + ", \u03C3=" + sig + ")";
    }

    @Override
    public String getDistributionName() {
        return "Student-T";
    }

    @Override
    public String[] getVariables() {
        return new String[] { "df", GreekLetters.mu, GreekLetters.sigma };
    }

    @Override
    public double[] getCurrentVariableValues() {
        return new double[] { df, mu, sig };
    }

    /**
     * Sets one of the parameters by name: {@code "df"},
     * {@link GreekLetters#mu} or {@link GreekLetters#sigma}. Unrecognized names
     * are ignored.
     * 
     * @param var   the parameter name
     * @param value the new value
     * @throws ArithmeticException if {@code df} or {@code sigma} is set to a
     *                             value that is not strictly positive
     */
    @Override
    public void setVariable(String var, double value) {
        if (var.equals("df")) {
            if (value > 0)
                df = value;
            else
                throw new ArithmeticException("Degrees of Fredom must be greater than 0");
        } else if (var.equals(GreekLetters.mu)) {
            mu = value;
        } else if (var.equals(GreekLetters.sigma)) {
            if (value > 0)
                sig = value;
            else
                throw new ArithmeticException("Standard deviation must be greater than zero");
        }
    }

    @Override
    public ContinuousDistribution clone() {
        return new StudentT(df, mu, sig);
    }

    /**
     * Fits the distribution to the given data by the method of moments.
     * 
     * @param data the samples to fit
     */
    @Override
    public void setUsingData(Vec data) {
        /*
         * While not true in every use of the t-distribution, we assume degrees of
         * freedom is n-1 if n is the number of samples
         */
        df = data.length() - 1;
        mu = data.mean();
        /*
         * Var[X] = sig^2 * df / (df - 2), so sig^2 = Var[X] * (df - 2) / df.
         * NOTE: this is only meaningful for df > 2 (at least 4 samples); for
         * df <= 2 the model variance is not finite and the estimate degenerates.
         */
        sig = sqrt(data.variance() * (df - 2) / df);
    }

    @Override
    public double mean() {
        return mu;
    }

    @Override
    public double median() {
        return mu;
    }

    @Override
    public double mode() {
        return mu;
    }

    /**
     * @return NaN for {@code df <= 1}, positive infinity for
     *         {@code 1 < df <= 2}, otherwise {@code sig^2 * df / (df - 2)}
     */
    @Override
    public double variance() {
        if (df <= 1)
            return Double.NaN;
        else if (1 < df && df <= 2)
            return Double.POSITIVE_INFINITY;
        return df / (df - 2) * sig * sig;
    }

    /**
     * @return NaN for {@code df <= 3}, otherwise 0 (the distribution is
     *         symmetric)
     */
    @Override
    public double skewness() {
        if (df <= 3)// Undefined for df <= 3
            return Double.NaN;
        return 0;
    }

    /**
     * Draws samples. For the standard case ({@code mu == 0}, {@code sig == 1})
     * each sample is {@code Z / sqrt(V / df)} with {@code Z} from
     * {@link Random#nextGaussian()} and {@code V} drawn from a chi-squared
     * distribution with {@code df} degrees of freedom. Otherwise this defers to
     * the superclass implementation.
     */
    @Override
    public double[] sample(int numSamples, Random rand) {
        if (mu != 0 || sig != 1)
            return super.sample(numSamples, rand);
        double[] sample = new ChiSquared(df).sample(numSamples, rand);
        for (int i = 0; i < sample.length; i++)
            sample[i] = rand.nextGaussian() / sqrt(sample[i] / df);
        return sample;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        long temp;
        temp = Double.doubleToLongBits(df);
        result = prime * result + (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(mu);
        result = prime * result + (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(sig);
        result = prime * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }
        StudentT other = (StudentT) obj;
        if (Double.doubleToLongBits(df) != Double.doubleToLongBits(other.df)) {
            return false;
        }
        if (Double.doubleToLongBits(mu) != Double.doubleToLongBits(other.mu)) {
            return false;
        }
        if (Double.doubleToLongBits(sig) != Double.doubleToLongBits(other.sig)) {
            return false;
        }
        return true;
    }

}
